"""
将模型参数匹配到Pytorch1.6以下的版本
1.6之后pytorch默认使用zip文件格式来保存模型参数, 导致这些文件无法直接被1.5及以下的pytorch加载

在pytorch1.6以上版本重新保存参数(非zip格式)如下,这些参数将能兼容1.6以下的版本
"""
import torch

dir_name=['FCN','Segnet','Unet','MultiCapsule','Capsule']

print('start fit pth...')

for i in range(len(dir_name)):
    state_dict = torch.load('../checkpoint/{}/model/netG_final.pth'.format(dir_name[i]),
                            map_location='cpu')
    torch.save(state_dict,
               '../checkpoint/{}/model/netG_final.pth'.format(dir_name[i]),
               _use_new_zipfile_serialization=False)

    for j in range(10):
        state_dict = torch.load('../checkpoint/{}/model/netG_{}.pth'.format(dir_name[i],j+1),
                                map_location='cpu')
        torch.save(state_dict,
                   '../checkpoint/{}/model/netG_{}.pth'.format(dir_name[i],j+1),
                   _use_new_zipfile_serialization=False)

print('fit pth done!')